import requests
from lxml import etree
import pymysql
import traceback
import time

# 关闭 https 相关的警告
requests.packages.urllib3.disable_warnings()

# 国家统计局 (National Bureau Of Statistics) 行政区划数据抓取的主URL
HOME_URL = "https://www.stats.gov.cn/sj/tjbz/tjyqhdmhcxhfdm/2023"

# 是否开启打印输出
ENABLE_PRINT = False

# 最大抓取深度，最抓取到哪一个层级的区域数据，总共5级
MAX_GRAB_LEVEL = 4

# 是否开启将数据写入到MySql
ENABLE_MYSQL_INSERTION = True

# 遇到列值为 null 时，是否跳过这条记录，继续向下执行
SKIP_NULL_COLUMN_VALUE = True

# 抓取网页出错时，是否继续向下执行
CONTINUE_ON_HTML_ERROR = True

# 抓取的最大数据条数，主要用于调代码，避免输出内容太多，负数代表抓取所有
MAX_GRAB_COUNT = -1

# 当前正在处理的省份，用于判断是否是直辖市
current_province_name = None

# 当前正在处理的城市名，用于判断提交MySql时，日志输出
current_city_name = None

# 连接MySql, 请根据实际情况调整参数
try:
    db = pymysql.connect(host='localhost', user='root', passwd='root', port=3306, db="my_db")
    cursor = db.cursor()
    print('连接Mysql成功！')
except:
    print('连接MySql失败')
    exit

def print_info(message:str):
    '''
    自定义一个内容输出方法，主要目的是可以统一控制是否输出，用于调试
    '''
    if ENABLE_PRINT:
        print(message) 

def insert_area_to_mysql(code:str, name:str, level:int, parent_code:str):
    '''
    插入一条记录到MySql，但不提交
    参数：
        code(str): 区域编码
        name(str): 区域名称
        level(int): 区域等级，
            1: 省/直辖市
            2: 市
            3: 区/县
            4: 乡镇/街道
            5: 社区/村委会
        parent_code(str): 父级编码
    '''
    if not ENABLE_MYSQL_INSERTION:
        return
    
    if code is None or name is None:
        print("发现null值：code={}, name={}, level={}, parent_code={}".format(code, name, level, parent_code))
        if SKIP_NULL_COLUMN_VALUE:
            return
        else:
            db.close()
            print("插入到MySql时遇到 Null 列值，程序将退出")
            exit()

    # 请根据实际情况调整表名
    sql = "insert into admin_area_2023(`code`, `name`, `level`, `parent_code`) values ('{}', '{}', {}, '{}')".format(code, name, level, parent_code)
    sql = sql.replace("'None'", 'NULL')
    print_info(sql)
    cursor.execute(sql)

def commit_for_mysql():
    global db, current_province_name
    try:
        db.commit()
        print("保存<{}·{}>行政区划数据到MySql成功".format(current_province_name, current_city_name))
    except Exception as e:
        db.rollback()
        print("保存" + current_province_name + "的行政区划数据到MySql失败")
        print(traceback.format_exc())

def get_admin_area_html(url:str):
    try_count = 0
    while try_count < 5:
        try_count += 1
        try:
            if try_count == 1:
                time.sleep(0.1)
            # 第一次抓取失败
            elif try_count == 2:
                time.sleep(1)
            else:
                time.sleep(2)

            response = requests.get(url)
            response.encoding = response.apparent_encoding
            return etree.HTML(response.text)
        except Exception:
            if try_count > 3:
                print(traceback.format_exc())
                print("连续 {} 次抓取 {} 页面时发生错误,。可能被服务怀疑是爬虫，拒绝了网络连接，因此休息10秒".format(try_count, url))
                time.sleep(10)
                return None
            else:
                print("第 {} 次抓取 {} 网页文本失败".format(try_count, url))

def grab_all_provinces():
    '''
    抓取所有省份
    '''
    html = get_admin_area_html(HOME_URL + "/index.html")
    province_nodes = html.xpath('//*/tr[@class="provincetr"]/td/a')

    grabed_count = 0
    for province_node in province_nodes:
        grabed_count += 1

        province_city_link = HOME_URL + "/" + province_node.attrib["href"]
        province_code = province_node.attrib["href"][0:2] + '0000000000'
        province_name = province_node.text.strip()
        global current_province_name
        current_province_name = province_name
        print_info("province_code={}, province_name={}".format(province_code, province_name))
        insert_area_to_mysql(province_code, province_name, 1, None)
        if MAX_GRAB_LEVEL >= 2:
            grab_province_cities(province_city_link, province_code, province_name)

        if MAX_GRAB_COUNT > 0 and grabed_count >= MAX_GRAB_COUNT:
            break

def grab_province_cities(province_city_link:str, province_code:str, province_name:str):
    '''
    抓取单个省/直辖市下的城市/区县
    参数:
        province_city_link(str): 省/直辖市区域页面的完整 url
        province_code(str): 城市所属的省份编码
        province_name(str): 城市所属的省份名称
    '''
    print("开始抓取省份（{}）的城市列表, URL={}".format(province_name, province_city_link))

    html = get_admin_area_html(province_city_link)
    if html is None:
        print("抓取省份（{}）的城市列表失败".format(province_name))
        return

    cityNodes = html.xpath('//*/tr[@class="citytr"]')

    grabed_count = 0
    global current_city_name
    for cityNode in cityNodes:
        link_nodes = cityNode.xpath('./*/a')
        city_code = link_nodes[0].text
        city_name = link_nodes[1].text.strip()
        current_city_name = city_name
        insert_area_to_mysql(city_code, city_name, 2, province_code)
        print_info("city_code={}, city_name={}".format(city_code, city_name))
        if MAX_GRAB_LEVEL >= 3 and link_nodes[1].attrib.has_key("href"):
            county_link = province_city_link[0:province_city_link.rfind('/')] + "/" + link_nodes[1].attrib["href"]
            grab_city_couties(county_link, city_code, city_name)

        # 以城市为最小提交单位
        commit_for_mysql()

        if MAX_GRAB_COUNT > 0 and grabed_count >= MAX_GRAB_COUNT:
            break

def grab_city_couties(city_county_link:str, city_code:str, city_name:str):
    '''
    抓取单个城市下的区/县
    参数:
        city_county_link(str): 城市区/县页面的完整 url
        city_code(str): 城市的编码
        city_name(str): 城市的名称
    '''
    print("开始抓取城市（{}）的区/县列表, URL={}".format(city_name, city_county_link))

    html = get_admin_area_html(city_county_link)
    if html is None:
        print("抓取城市（{}）的区/县列表失败".format(city_name))
        return
    
    county_nodes = html.xpath('//*/tr[@class="countytr"]')
    grabed_count = 0
    global current_province_name
    for county_node in county_nodes:
        grabed_count += 1
        county_link_nodes = county_node.xpath("./*/a")
        if len(county_link_nodes) == 0:
            # 没有<a>标签，通常是直辖市的市辖区，内容抓取方式不同
            county_code = county_node.xpath("./td")[0].text
            county_name = county_node.xpath("./td")[1].text
            insert_area_to_mysql(county_code, county_name, 3, city_code)
            print_info("county_code={}, county_name={}, parent_code={}".format(county_code, county_name, city_code))
        else:
            county_code = county_link_nodes[0].text
            county_name = county_link_nodes[1].text
            insert_area_to_mysql(county_code, county_name, 3, city_code)
            print_info("county_code={}, county_name={}, level=2, parent_code = {}".format(county_code, county_name, city_code))
            if MAX_GRAB_LEVEL >= 4 and county_link_nodes[1].attrib.has_key("href"):
                town_link = city_county_link[0:city_county_link.rfind("/")] + "/" + county_link_nodes[1].attrib["href"]
                grab_county_towns(town_link, county_code, county_name)
        
        if MAX_GRAB_COUNT > 0 and grabed_count >= MAX_GRAB_COUNT:
            break

def grab_county_towns(county_town_link:str, county_code:str, county_name:str):
    '''
    抓取单个区/县下的乡镇/街道
    参数:
        county_town_link(str): 乡镇/街道数据页面完整的 url
        county_code(str): 区/县的编码
        county_name(str): 区/县的名称
    '''
    print("开始抓取区县（{}）的街道/乡镇列表, URL={}".format(county_name, county_town_link))

    html = get_admin_area_html(county_town_link)
    if html is None:
        print("抓取区县（{}）的街道/乡镇列表失败".format(county_name))
        return
    
    town_nodes = html.xpath('//*/tr[@class="towntr"]')
    grabed_count = 0
    for town_node in town_nodes:
        grabed_count += 1
        village_link_nodes = town_node.xpath('./*/a')
        town_code = village_link_nodes[0].text
        town_name = village_link_nodes[1].text
        print_info("town_code={}, town_name={}".format(town_code, town_name))
        insert_area_to_mysql(town_code, town_name, 4, county_code)
        if MAX_GRAB_LEVEL >= 5 and village_link_nodes[1].attrib.has_key("href"):
            village_link = county_town_link[0:county_town_link.rfind("/")] + "/" + village_link_nodes[1].attrib["href"]
            grab_town_villages(village_link, town_code, town_name)

        if MAX_GRAB_COUNT > 0 and grabed_count >= MAX_GRAB_COUNT:
            break

def grab_town_villages(town_village_url:str, town_code:str, town_name:str):
    '''
    抓取单个街道/乡镇下的社区/村委会
    参数:
        town_village_url(str): 社区/村委会数据页面完整的 url
        town_code(str): 街道/乡镇的编码
        town_name(str): 街道/乡镇的名称
    '''
    print_info("开始抓取街道/乡镇下（{}）的社区/村委会列表, URL={}".format(town_name, town_village_url))

    html = get_admin_area_html(town_village_url)
    if html is None:
        print("抓取街道/乡镇下（{}）的社区/村委会列表失败".format(town_name))
        return
    
    village_nodes = html.xpath('//*/tr[@class="villagetr"]')
    grabed_count = 0
    for village_node in village_nodes:
        grabed_count += 1
        village_info_columns = village_node.xpath('./td')
        village_code = village_info_columns[0].text
        village_name = village_info_columns[2].text
        insert_area_to_mysql(village_code, village_name, 5, town_code)
        print_info("village_code={}, village_code={}".format(village_code, village_name))

        if MAX_GRAB_COUNT > 0 and grabed_count >= MAX_GRAB_COUNT:
            break

# 正式执行数据抓取任务
grab_all_provinces()

db.close()